[webgpu] Register GQA based on graph capture #26384
Merged
Conversation
guschmue (Contributor) approved these changes on Oct 28, 2025
Pull Request Overview
This PR enables conditional registration of the GroupQueryAttention (GQA) operator based on whether graph capture is enabled in the WebGPU execution provider. When graph capture is enabled, the operator reads total sequence length from GPU buffers instead of CPU memory, eliminating the need for a MemcpyToHost operation that was blocking graph capture support.
Key changes:
- Modified GQA kernel registration to conditionally set InputMemoryType based on graph capture status
- Updated flash attention shader templates and programs to support reading sequence length from GPU buffers
- Added validation logic to handle total_seqlen tensor when it resides on GPU during graph capture
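The registration change described above can be sketched as follows. This is a minimal illustration, not ORT's actual API: the enum and helper names are hypothetical, and the input index 6 (where `total_sequence_length` sits in the GQA schema) is an assumption. The idea is that without graph capture the `total_seqlen` input is pinned to CPU memory so the kernel can read it on the host, while with graph capture it stays in GPU default memory so no MemcpyToHost node is inserted.

```cpp
#include <cassert>

// Illustrative stand-ins for ORT memory types (hypothetical names).
enum class MemType { GpuDefault, CpuInput };

// Sketch: decide where GQA's total_seqlen input should live,
// mirroring the PR's conditional registration.
MemType TotalSeqlenMemoryType(bool enable_graph_capture) {
  // Graph capture forbids host-side reads mid-graph, so the tensor
  // must stay on the GPU; otherwise pin it to CPU as before.
  return enable_graph_capture ? MemType::GpuDefault : MemType::CpuInput;
}
```

In the real code this decision would feed into the kernel definition (e.g. whether an `InputMemoryType` constraint is applied at registration time) rather than a standalone helper.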
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc | Passes enable_graph_capture flag to RegisterWebGpuContribKernels |
| onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h | Adds enable_graph_capture parameter to RegisterWebGpuContribKernels signature |
| onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc | Replaces static GQA registration with conditional registration via CreateGroupQueryAttentionKernelInfo |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h | Declares CreateGroupQueryAttentionKernelInfo function for conditional kernel creation |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Implements conditional kernel registration and updates ApplyFlashAttention signature to accept seqlen_k |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template | Adds get_total_sequence_length() function that reads from either GPU buffer or uniforms based on use_seqlen_k flag |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | Adds use_seqlen_k member to CopyKVCacheProgram and FlashAttentionProgram classes |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Implements use_seqlen_k logic in shader code generation and removes past_sequence_length uniform |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h | Updates validation logic to skip CPU-specific checks when total_seqlen is on GPU |
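As a rough illustration of the shader-side change listed in the table (the emitted WGSL strings are simplified placeholders, not the actual template output): the codegen can emit a `get_total_sequence_length()` helper whose body depends on the `use_seqlen_k` flag, reading from the `seqlen_k` GPU buffer under graph capture and from uniforms otherwise. The `+ 1u` reflects an assumed off-by-one convention for `seqlen_k`.

```cpp
#include <cassert>
#include <string>

// Hypothetical sketch of branching WGSL codegen on use_seqlen_k.
std::string GetTotalSeqLenHelper(bool use_seqlen_k) {
  if (use_seqlen_k) {
    // Graph capture: the value lives in a GPU storage buffer.
    return "fn get_total_sequence_length() -> u32 {\n"
           "  return u32(seqlen_k[0]) + 1u;\n"  // assumed convention
           "}\n";
  }
  // No capture: the value was read on the CPU and passed as a uniform.
  return "fn get_total_sequence_length() -> u32 {\n"
         "  return uniforms.total_sequence_length;\n"
         "}\n";
}
```

Because both branches expose the same WGSL function name, the rest of the shader template can call `get_total_sequence_length()` without caring which path was compiled in.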
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
naomiOvad pushed a commit to naomiOvad/onnxruntime that referenced this pull request on Nov 2, 2025
This pull request enables conditional registration of GQA depending on whether `total_sequence_length` resides on GPU. It resolves the issue that a MemcpyToHost node is generated when graph capture is enabled (refer to microsoft#25868). This is the last functional piece needed to support graph capture in the WebGPU EP in ORT. The main changes ensure that when graph capture is enabled, sequence length information is read from GPU buffers instead of CPU memory, and shader code generation adapts accordingly. This enables more efficient execution and compatibility with graph-captured models. In this PR, we still read the total sequence length from the `seqlen_k` tensor rather than the `total_seqlen_tensor` tensor, to stay consistent with other parts of the code. In a follow-up PR, we can refactor all call sites to use `total_seqlen_tensor` directly instead of `seqlen_k` when graph capture is enabled.
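The validation change in group_query_attention_helper.h can be sketched roughly like this (the struct, function, and check below are placeholders, not the actual helper code): checks that dereference the tensor's value only run when the data is host-accessible, since under graph capture a GPU-resident `total_seqlen` cannot be read on the CPU.

```cpp
#include <cassert>

// Hypothetical stand-in for a tensor whose data may live on CPU or GPU.
struct SeqLenTensor {
  bool on_gpu;  // true under graph capture in the WebGPU EP
  int value;    // only meaningful when the data is host-accessible
};

// Sketch: validate total_seqlen, skipping value checks for GPU data.
bool ValidateTotalSeqlen(const SeqLenTensor& t, int past_plus_new_seqlen) {
  if (t.on_gpu) {
    // The value cannot be read on the host during graph capture,
    // so value-dependent checks are deferred to the GPU-side kernel.
    return true;
  }
  // CPU path: the value must cover past + new tokens (illustrative).
  return t.value >= past_plus_new_seqlen;
}
```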